{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install allennlp\n",
    "!git clone https://github.com/mhagiwara/realworldnlp.git\n",
    "%cd realworldnlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from allennlp.common.file_utils import cached_path\n",
    "from allennlp.data import DatasetReader\n",
    "from allennlp.data.fields import TextField, SequenceLabelField\n",
    "from allennlp.data.instance import Instance\n",
    "from allennlp.data.iterators import BucketIterator\n",
    "from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer\n",
    "from allennlp.data.tokenizers.token import Token\n",
    "from allennlp.data.vocabulary import Vocabulary\n",
    "from allennlp.models import Model\n",
    "from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper\n",
    "from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder\n",
    "from allennlp.modules.token_embedders import Embedding\n",
    "from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits\n",
    "from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure\n",
    "from allennlp.training.trainer import Trainer\n",
    "from overrides import overrides"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_SIZE = 128\n",
    "HIDDEN_SIZE = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NERDatasetReader(DatasetReader):\n",
    "    def __init__(self, token_indexers: Dict[str, TokenIndexer]=None, lazy=False):\n",
    "        super().__init__(lazy=lazy)\n",
    "        self.token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}\n",
    "\n",
    "    @overrides\n",
    "    def text_to_instance(self, tokens: List[Token], labels: List[str]=None):\n",
    "        fields = {}\n",
    "\n",
    "        text_field = TextField(tokens, self.token_indexers)\n",
    "        fields['tokens'] = text_field\n",
    "        if labels:\n",
    "            fields['labels'] = SequenceLabelField(labels, text_field)\n",
    "\n",
    "        return Instance(fields)\n",
    "\n",
    "    def _convert_sentence(self, rows: List[Tuple[str]]) -> Tuple[List[Token], List[str]]:\n",
    "        \"\"\"Given a list of rows, returns tokens and labels.\"\"\"\n",
    "        _, tokens, _, labels = zip(*rows)\n",
    "        tokens = [Token(t) for t in tokens]\n",
    "\n",
    "        # NOTE: the original dataset seems to confuse gpe with geo, and the distinction\n",
    "        # seems arbitrary. Here we replace both with 'gpe'\n",
    "        labels = [label.replace('geo', 'gpe') for label in labels]\n",
    "        return tokens, labels\n",
    "\n",
    "    @overrides\n",
    "    def _read(self, file_path: str):\n",
    "\n",
    "        file_path = cached_path(file_path)\n",
    "        sentence = []\n",
    "        with open(file_path, mode='r', encoding='utf-8', errors='ignore') as csv_file:\n",
    "            next(csv_file)\n",
    "            reader = csv.reader(csv_file)\n",
    "\n",
    "            for row in reader:\n",
    "                if row[0] and sentence:\n",
    "                    tokens, labels = self._convert_sentence(sentence)\n",
    "                    yield self.text_to_instance(tokens, labels)\n",
    "\n",
    "                    sentence = [row]\n",
    "                else:\n",
    "                    sentence.append(row)\n",
    "\n",
    "            if sentence:\n",
    "                tokens, labels = self._convert_sentence(sentence)\n",
    "                yield self.text_to_instance(tokens, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LstmTagger(Model):\n",
    "    def __init__(self,\n",
    "                 embedder: TextFieldEmbedder,\n",
    "                 encoder: Seq2SeqEncoder,\n",
    "                 vocab: Vocabulary) -> None:\n",
    "        super().__init__(vocab)\n",
    "        self.embedder = embedder\n",
    "        self.encoder = encoder\n",
    "        self.hidden2labels = torch.nn.Linear(in_features=encoder.get_output_dim(),\n",
    "                                             out_features=vocab.get_vocab_size('labels'))\n",
    "        self.accuracy = CategoricalAccuracy()\n",
    "        self.f1 = SpanBasedF1Measure(vocab, tag_namespace='labels')\n",
    "\n",
    "    def forward(self,\n",
    "                tokens: Dict[str, torch.Tensor],\n",
    "                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:\n",
    "        mask = get_text_field_mask(tokens)\n",
    "        embeddings = self.embedder(tokens)\n",
    "        encoder_out = self.encoder(embeddings, mask)\n",
    "        logits = self.hidden2labels(encoder_out)\n",
    "        output = {'logits': logits}\n",
    "        if labels is not None:\n",
    "            self.accuracy(logits, labels, mask)\n",
    "            self.f1(logits, labels, mask)\n",
    "            output['loss'] = sequence_cross_entropy_with_logits(logits, labels, mask)\n",
    "\n",
    "        return output\n",
    "\n",
    "    def get_metrics(self, reset: bool = False) -> Dict[str, float]:\n",
    "        f1_metrics = self.f1.get_metric(reset)\n",
    "        return {'accuracy': self.accuracy.get_metric(reset),\n",
    "                'prec': f1_metrics['precision-overall'],\n",
    "                'rec': f1_metrics['recall-overall'],\n",
    "                'f1': f1_metrics['f1-measure-overall']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "reader = NERDatasetReader()\n",
    "dataset = list(reader.read('https://s3.amazonaws.com/realworldnlpbook/data/entity-annotated-corpus/ner_dataset.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(tokens: List[str], model: LstmTagger) -> List[str]:\n",
    "    token_indexers = {'tokens': SingleIdTokenIndexer()}\n",
    "    tokens = [Token(t) for t in tokens]\n",
    "    inst = Instance({'tokens': TextField(tokens, token_indexers)})\n",
    "    logits = model.forward_on_instance(inst)['logits']\n",
    "    label_ids = np.argmax(logits, axis=1)\n",
    "    labels = [model.vocab.get_token_from_index(label_id, 'labels')\n",
    "              for label_id in label_ids]\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = [inst for i, inst in enumerate(dataset) if i % 10 != 0]\n",
    "dev_dataset = [inst for i, inst in enumerate(dataset) if i % 10 == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = Vocabulary.from_instances(train_dataset + dev_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),\n",
    "                            embedding_dim=EMBEDDING_SIZE)\n",
    "word_embeddings = BasicTextFieldEmbedder({\"tokens\": token_embedding})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = PytorchSeq2SeqWrapper(\n",
    "    torch.nn.LSTM(EMBEDDING_SIZE, HIDDEN_SIZE, bidirectional=True, batch_first=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LstmTagger(word_embeddings, lstm, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterator = BucketIterator(batch_size=16, sorting_keys=[(\"tokens\", \"num_tokens\")])\n",
    "iterator.index_with(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(model=model,\n",
    "                  optimizer=optimizer,\n",
    "                  iterator=iterator,\n",
    "                  train_dataset=train_dataset,\n",
    "                  validation_dataset=dev_dataset,\n",
    "                  patience=10,\n",
    "                  num_epochs=10)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = ['Apple', 'is', 'looking', 'to', 'buy', 'U.K.', 'startup', 'for', '$1', 'billion', '.']\n",
    "labels = predict(tokens, model)\n",
    "print(' '.join('{}/{}'.format(token, label) for token, label in zip(tokens, labels)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
